{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-5Qo12uLFZ98"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Load Kaggle API Token\n",
        "from google.colab import files\n",
        "files.upload()\n",
        "!pip install -q kaggle\n",
        "!mkdir ~/.kaggle\n",
        "!cp kaggle.json ~/.kaggle/\n",
        "!chmod 600 ~/.kaggle/kaggle.json\n",
        "!kaggle datasets download lifailifai/ucf101"
      ],
      "metadata": {
        "id": "rJnwf-gE5-z8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DiSt42g4SCik"
      },
      "source": [
        "# Install"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0DKusffxPOfj"
      },
      "outputs": [],
      "source": [
        "!pip install rudalle==0.0.1rc4"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install transformers"
      ],
      "metadata": {
        "id": "IJNA_Yui7DzX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!gdown https://drive.google.com/uc?id=1_B9Y2U9d-xIO1pKb4J9qsVltMCH_ZNOy"
      ],
      "metadata": {
        "id": "PdCnU__h8eWs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "frQ29AqXSVcb"
      },
      "source": [
        "# Unzip data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g2j_g_T7wiQd"
      },
      "outputs": [],
      "source": [
        "#@markdown Lets download data\n",
        "!unzip ucf101.zip"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Import"
      ],
      "metadata": {
        "id": "VXo4F17k6Voz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tJ9LJ6rB0K7-"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import os\n",
        "import PIL\n",
        "import random\n",
        "import numpy as np\n",
        "import torch\n",
        "import torchvision\n",
        "import transformers\n",
        "import more_itertools\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from tqdm import tqdm\n",
        "import pandas as pd\n",
        "from torch.utils.data import Dataset\n",
        "from tqdm import tqdm\n",
        "from dataclasses import dataclass, field\n",
        "import torchvision.transforms as T\n",
        "import torchvision.transforms.functional as TF\n",
        "import cv2\n",
        "from PIL import Image\n",
        "from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip\n",
        "from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip\n",
        "from rudalle.utils import seed_everything"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load ruDALLe"
      ],
      "metadata": {
        "id": "Jqpk_rYT6e36"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4v1KNh9kFEwo"
      },
      "outputs": [],
      "source": [
        "device = 'cuda'\n",
        "model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)\n",
        "vae = get_vae().to(device)\n",
        "tokenizer = get_tokenizer()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Args"
      ],
      "metadata": {
        "id": "Gk31mh1L6oK4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Args():\n",
        "    def __init__(self):\n",
        "     \n",
        "        self.text_seq_length = model.get_param('text_seq_length')\n",
        "        self.total_seq_length = model.get_param('total_seq_length')\n",
        "        self.epochs = 1\n",
        "        self.save_path='checkpoints/'\n",
        "        self.model_name = 'awesomemodel_'\n",
        "        self.save_every = 2000\n",
        "        self.prefix_length = 10\n",
        "        self.clip = 0.24\n",
        "        self.lr = 2e-5\n",
        "        self.warmup_steps =50\n",
        "args = Args()\n",
        "if not os.path.exists(args.save_path):\n",
        "        os.makedirs(args.save_path)"
      ],
      "metadata": {
        "id": "ThTEtqfy6ong"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Dataset"
      ],
      "metadata": {
        "id": "U4TUXpTH6wAr"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BJE5rkglgfJI"
      },
      "outputs": [],
      "source": [
        "def get_all_video_path(dir):\n",
        "    paths = []\n",
        "    for folder in os.listdir(dir):\n",
        "        new_path = os.path.join(dir, folder)\n",
        "        for file_name in os.listdir(new_path):\n",
        "            paths.append(os.path.join(new_path, file_name))\n",
        "    return paths"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xFgZe6I7MqL7"
      },
      "outputs": [],
      "source": [
        "def read_video(path, transform=None, frames_num=None):\n",
        "    frames = []\n",
        "    cap = cv2.VideoCapture(path)\n",
        "    while(cap.isOpened()):\n",
        "        ret, frame = cap.read()\n",
        "        \n",
        "        if ret:\n",
        "            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
        "            if transform is not None:\n",
        "                frame = transform(frame)\n",
        "            frames.append(frame)\n",
        "            if frames_num is not None:\n",
        "                if len(frames) >= frames_num:\n",
        "                    break\n",
        "        else:\n",
        "            break\n",
        "    cap.release()\n",
        "    return torch.stack(frames)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IN6IHOA7MYxq"
      },
      "outputs": [],
      "source": [
        "class RuDalleVideoDataset(Dataset):\n",
        "    def __init__(\n",
        "            self,\n",
        "            dir_path,\n",
        "            csv_path,\n",
        "            tokenizer,\n",
        "    ):\n",
        "        \"\"\" tokenizer - объект с методами tokenizer_wrapper.BaseTokenizerWrapper \"\"\"\n",
        "        self.df = pd.read_csv(csv_path)#'/content/drive/MyDrive/ucf100_ru.csv'\n",
        "        self.df.index = self.df['en'].values\n",
        "        self.paths = get_all_video_path(dir_path)\n",
        "        self.text_seq_length = model.get_param('text_seq_length')\n",
        "        self.tokenizer = tokenizer\n",
        "        self.image_size = 128\n",
        "\n",
        "        self.image_transform = T.Compose([\n",
        "                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
        "                T.RandomResizedCrop(self.image_size,\n",
        "                                    scale=(1., 1.), # в train было scale=(0.75., 1.),\n",
        "                                    ratio=(1., 1.)),\n",
        "                T.ToTensor()\n",
        "            ])\n",
        "\n",
        "    \n",
        "    def __len__(self):\n",
        "        return len(self.paths)\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        path = self.paths[item]\n",
        "        ind_of_text = path.split('/')[3]\n",
        "        try:\n",
        "          video = read_video(path, self.image_transform, 50)\n",
        "          if len(video) < 50:\n",
        "              print('Video is shorter than 50 frames')\n",
        "              self.__getitem__(random.randint(0, len(self.paths) - 1))\n",
        "          #new_video = torch.stack([video[0], video[7], video[14], video[21]], dim=0)\n",
        "          #new_video = torch.stack([video[0], video[3], video[7], video[11]], dim=0)\n",
        "          new_video = torch.stack([video[0], video[16], video[32], video[48]], dim=0)\n",
        "        except Exception as err:  # noqa\n",
        "            print(err)\n",
        "            return self.__getitem__(random.randint(0, len(self.paths) - 1))\n",
        "        ru_text = self.df.loc[ind_of_text]['ru'].lower().capitalize()\n",
        "        \n",
        "        text = self.tokenizer.encode_text(ru_text, text_seq_length=self.text_seq_length).squeeze(0)\n",
        "        return text, new_video"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5Ms0yhmCASvo"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import Dataset, DataLoader\n",
        "st = RuDalleVideoDataset(dir_path='/content/UCF-101', csv_path='ucf100_ru2.csv', tokenizer=tokenizer)\n",
        "train_dataloader = DataLoader(st, batch_size=1, shuffle=True, drop_last=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train functions"
      ],
      "metadata": {
        "id": "sVtDeJzg65iI"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jMc-jGjDs5J7"
      },
      "outputs": [],
      "source": [
        "def freeze(\n",
        "    model,\n",
        "    freeze_emb=True,\n",
        "    freeze_ln=False,\n",
        "    freeze_attn=False,\n",
        "    freeze_ff=True,\n",
        "    freeze_other=True,\n",
        "):\n",
        "    for name, p in model.module.named_parameters():\n",
        "        name = name.lower()\n",
        "        if 'ln' in name or 'norm' in name:\n",
        "            p.requires_grad = not freeze_ln\n",
        "        elif 'embeddings' in name:\n",
        "            p.requires_grad = not freeze_emb\n",
        "        elif 'mlp' in name:\n",
        "            p.requires_grad = not freeze_ff\n",
        "        elif 'attn' in name:\n",
        "            p.requires_grad = not freeze_attn\n",
        "        else:\n",
        "            p.requires_grad = not freeze_other\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-hVTQcDF1EbM"
      },
      "outputs": [],
      "source": [
        "#markdown Simple training loop\n",
        "def train(model,args, train_dataloader, device):\n",
        "    loss_logs = []\n",
        "    progress = tqdm(total=len(train_dataloader), desc='finetuning')\n",
        "    save_counter = 0\n",
        "    device = model.get_param('device')\n",
        "    for epoch in range(args.epochs):\n",
        "      \n",
        "      for text, images in train_dataloader:\n",
        "        text = text.to(device)\n",
        "        images = images.to(device)\n",
        "        save_counter+=1\n",
        "\n",
        "        model.zero_grad()\n",
        "        attention_mask = torch.tril(torch.ones((1, 1, args.total_seq_length, args.total_seq_length), device=device))\n",
        "        with torch.no_grad():\n",
        "            image_input_ids = vae.get_codebook_indices(images[0]).flatten().unsqueeze(0)\n",
        "        #print(text.shape, image_input_ids.shape)\n",
        "        input_ids = torch.cat((text, image_input_ids), dim=1) \n",
        "        loss, loss_values = model.forward(input_ids, attention_mask, return_loss=True)\n",
        "        #train step\n",
        "        loss.backward()\n",
        "            \n",
        "        torch.nn.utils.clip_grad_norm_(model.parameters(),args.clip)\n",
        "        optimizer.step()\n",
        "        scheduler.step()\n",
        "        optimizer.zero_grad()\n",
        "        #save every here\n",
        "        if save_counter % args.save_every == 0:\n",
        "          print(f'Saveing checkpoint here {args.model_name}_dalle_{save_counter}.pt')\n",
        "          \n",
        "          plt.plot(loss_logs)\n",
        "          plt.show()\n",
        "          torch.save(\n",
        "                    model.state_dict(),\n",
        "                    os.path.join(args.save_path,f\"{args.model_name}_dalle_{save_counter}.pt\")\n",
        "                    )\n",
        "\n",
        "        loss_logs+=[loss.item()]\n",
        "        progress.update()\n",
        "        progress.set_postfix({\"loss\": loss.item()})\n",
        "\n",
        "    print(f'Complitly tuned and saved here  {args.model_name}__dalle_last.pt')\n",
        "    \n",
        "    plt.plot(loss_logs)\n",
        "    plt.show()\n",
        "    \n",
        "    torch.save(\n",
        "                model.state_dict(),\n",
        "                'videodalle_new.pt'#os.path.join(args.save_path,f\"{args.model_name}_dalle_last.pt\")\n",
        "                )"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train"
      ],
      "metadata": {
        "id": "BAAg315r9KEq"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-4GDi4BjHiHr"
      },
      "outputs": [],
      "source": [
        "from transformers import  AdamW, get_linear_schedule_with_warmup\n",
        "model.train()\n",
        "optimizer = AdamW(model.parameters(), lr = args.lr)\n",
        "\n",
        "scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, \n",
        "                                                final_div_factor=500,  \n",
        "                                                steps_per_epoch=len(train_dataloader), epochs=args.epochs )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dGPS3HnDN3wO"
      },
      "outputs": [],
      "source": [
        "#@markdown You can unfreeze or freeze more parametrs, but it can \n",
        "model = freeze(model = model,\n",
        "    freeze_emb=False,\n",
        "    freeze_ln=False,\n",
        "    freeze_attn=False,\n",
        "    freeze_ff=True,\n",
        "    freeze_other=False)#freeze params to \n",
        "\n",
        "train(model, args, train_dataloader, device)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "background_execution": "on",
      "collapsed_sections": [
        "DiSt42g4SCik",
        "frQ29AqXSVcb",
        "VXo4F17k6Voz",
        "Jqpk_rYT6e36",
        "Gk31mh1L6oK4",
        "U4TUXpTH6wAr",
        "sVtDeJzg65iI",
        "BAAg315r9KEq"
      ],
      "machine_shape": "hm",
      "name": "videodalle_finetune.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
